// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package mssql

import (
	"fmt"
	"log"
	"time"

	"github.com/hashicorp/go-azure-helpers/lang/response"
	"github.com/hashicorp/go-azure-helpers/resourcemanager/commonids"
	"github.com/hashicorp/go-azure-sdk/resource-manager/sql/2023-08-01-preview/databasevulnerabilityassessmentrulebaselines"
	"github.com/hashicorp/terraform-provider-azurerm/internal/clients"
	"github.com/hashicorp/terraform-provider-azurerm/internal/services/mssql/parse"
	"github.com/hashicorp/terraform-provider-azurerm/internal/services/mssql/validate"
	"github.com/hashicorp/terraform-provider-azurerm/internal/tf/pluginsdk"
	"github.com/hashicorp/terraform-provider-azurerm/internal/tf/validation"
	"github.com/hashicorp/terraform-provider-azurerm/internal/timeouts"
)

// resourceMsSqlDatabaseVulnerabilityAssessmentRuleBaseline defines the schema and CRUD handlers for a
// baseline attached to a single vulnerability assessment rule of an MS SQL database.
//
// Create and Update share one handler because the underlying API is a PUT (CreateOrUpdate). Every
// identifying argument is ForceNew, so only `baseline_result` can be changed in place.
func resourceMsSqlDatabaseVulnerabilityAssessmentRuleBaseline() *pluginsdk.Resource {
	return &pluginsdk.Resource{
		Create: resourceMsSqlDatabaseVulnerabilityAssessmentRuleBaselineCreateUpdate,
		Read:   resourceMsSqlDatabaseVulnerabilityAssessmentRuleBaselineRead,
		Update: resourceMsSqlDatabaseVulnerabilityAssessmentRuleBaselineCreateUpdate,
		Delete: resourceMsSqlDatabaseVulnerabilityAssessmentRuleBaselineDelete,

		Importer: pluginsdk.ImporterValidatingResourceId(func(id string) error {
			_, err := parse.DatabaseVulnerabilityAssessmentRuleBaselineID(id)
			return err
		}),

		Timeouts: &pluginsdk.ResourceTimeout{
			Create: pluginsdk.DefaultTimeout(30 * time.Minute),
			Read:   pluginsdk.DefaultTimeout(5 * time.Minute),
			Update: pluginsdk.DefaultTimeout(30 * time.Minute),
			Delete: pluginsdk.DefaultTimeout(30 * time.Minute),
		},

		Schema: map[string]*pluginsdk.Schema{
			// ID of the parent Server Vulnerability Assessment; its subscription, resource group,
			// server and assessment name segments are reused to build this resource's ID.
			"server_vulnerability_assessment_id": {
				Type:         pluginsdk.TypeString,
				Required:     true,
				ForceNew:     true,
				ValidateFunc: validate.ServerVulnerabilityAssessmentID,
			},

			"database_name": {
				Type:         pluginsdk.TypeString,
				Required:     true,
				ForceNew:     true,
				ValidateFunc: validate.ValidateMsSqlDatabaseName,
			},

			// The rule ID becomes a path segment of the resource ID. An empty value would produce a
			// malformed ID, so reject it at plan time instead of failing at apply.
			"rule_id": {
				Type:         pluginsdk.TypeString,
				Required:     true,
				ForceNew:     true,
				ValidateFunc: validation.StringIsNotEmpty,
			},

			"baseline_name": {
				Type:     pluginsdk.TypeString,
				Optional: true,
				ForceNew: true,
				Default:  string(databasevulnerabilityassessmentrulebaselines.VulnerabilityAssessmentPolicyBaselineNameDefault),
				ValidateFunc: validation.StringInSlice([]string{
					string(databasevulnerabilityassessmentrulebaselines.VulnerabilityAssessmentPolicyBaselineNameDefault),
					string(databasevulnerabilityassessmentrulebaselines.VulnerabilityAssessmentPolicyBaselineNameMaster),
				}, false),
			},

			// Each block is one baseline row; `result` holds that row's column values in order.
			"baseline_result": {
				Type:     pluginsdk.TypeSet,
				Required: true,
				Elem: &pluginsdk.Resource{
					Schema: map[string]*pluginsdk.Schema{
						"result": {
							Type:     pluginsdk.TypeList,
							Required: true,
							Elem: &pluginsdk.Schema{
								Type:         pluginsdk.TypeString,
								ValidateFunc: validation.StringIsNotEmpty,
							},
						},
					},
				},
			},
		},
	}
}

// resourceMsSqlDatabaseVulnerabilityAssessmentRuleBaselineCreateUpdate creates or replaces the baseline
// for one vulnerability assessment rule on an MS SQL database, then refreshes state via Read.
//
// The parent Server Vulnerability Assessment is fetched first. An error is returned if it has no
// storage container path configured.
func resourceMsSqlDatabaseVulnerabilityAssessmentRuleBaselineCreateUpdate(d *pluginsdk.ResourceData, meta interface{}) error {
	client := meta.(*clients.Client).MSSQL.DatabaseVulnerabilityAssessmentRuleBaselinesClient
	vulnerabilityClient := meta.(*clients.Client).MSSQL.ServerVulnerabilityAssessmentsClient
	ctx, cancel := timeouts.ForCreateUpdate(meta.(*clients.Client).StopContext, d)
	defer cancel()

	log.Printf("[INFO] preparing arguments for Azure ARM Vulnerability Assessment Rule Baselines creation.")

	vulnerabilityAssessmentId, err := parse.ServerVulnerabilityAssessmentID(d.Get("server_vulnerability_assessment_id").(string))
	if err != nil {
		return err
	}

	serverId := commonids.NewSqlServerID(vulnerabilityAssessmentId.SubscriptionId, vulnerabilityAssessmentId.ResourceGroup, vulnerabilityAssessmentId.ServerName)

	vulnerabilityAssessment, err := vulnerabilityClient.Get(ctx, serverId)
	if err != nil {
		return fmt.Errorf("retrieving Server Vulnerability Assessment Settings for %s: %+v", serverId, err)
	}
	if vulnerabilityAssessment.Model == nil || vulnerabilityAssessment.Model.Properties == nil || vulnerabilityAssessment.Model.Properties.StorageContainerPath == "" {
		return fmt.Errorf("storage container path not set in Server Vulnerability Assessment Settings")
	}

	// The resource ID reuses the parent assessment's identifying segments plus the database,
	// rule and baseline names from configuration.
	id := parse.NewDatabaseVulnerabilityAssessmentRuleBaselineID(vulnerabilityAssessmentId.SubscriptionId,
		vulnerabilityAssessmentId.ResourceGroup,
		vulnerabilityAssessmentId.ServerName,
		d.Get("database_name").(string),
		vulnerabilityAssessmentId.VulnerabilityAssessmentName,
		d.Get("rule_id").(string),
		d.Get("baseline_name").(string))

	// The SDK ID carries no vulnerability assessment name segment.
	baselineId := databasevulnerabilityassessmentrulebaselines.NewBaselineID(id.SubscriptionId, id.ResourceGroup, id.ServerName, id.DatabaseName, id.RuleName, databasevulnerabilityassessmentrulebaselines.VulnerabilityAssessmentPolicyBaselineName(id.BaselineName))

	// TODO: Create does not yet check for a pre-existing baseline (requires-import behaviour),
	// so an existing baseline is silently overwritten.
	parameters := expandBaselineResults(d.Get("baseline_result").(*pluginsdk.Set))

	if _, err := client.CreateOrUpdate(ctx, baselineId, *parameters); err != nil {
		return fmt.Errorf("creating/updating %s: %+v", id, err)
	}

	d.SetId(id.ID())
	return resourceMsSqlDatabaseVulnerabilityAssessmentRuleBaselineRead(d, meta)
}

// resourceMsSqlDatabaseVulnerabilityAssessmentRuleBaselineRead refreshes state from the API.
//
// If the baseline no longer exists (404), the resource is removed from state by clearing its ID.
// Identifying attributes are derived from the parsed resource ID. `baseline_result` is populated
// from the API response only when the response carries properties.
func resourceMsSqlDatabaseVulnerabilityAssessmentRuleBaselineRead(d *pluginsdk.ResourceData, meta interface{}) error {
	client := meta.(*clients.Client).MSSQL.DatabaseVulnerabilityAssessmentRuleBaselinesClient
	ctx, cancel := timeouts.ForRead(meta.(*clients.Client).StopContext, d)
	defer cancel()

	id, err := parse.DatabaseVulnerabilityAssessmentRuleBaselineID(d.Id())
	if err != nil {
		return err
	}

	baselineId := databasevulnerabilityassessmentrulebaselines.NewBaselineID(id.SubscriptionId, id.ResourceGroup, id.ServerName, id.DatabaseName, id.RuleName, databasevulnerabilityassessmentrulebaselines.VulnerabilityAssessmentPolicyBaselineName(id.BaselineName))

	result, err := client.Get(ctx, baselineId)
	if err != nil {
		if response.WasNotFound(result.HttpResponse) {
			log.Printf("[WARN] %s was not found", *id)
			d.SetId("")
			return nil
		}

		return fmt.Errorf("retrieving %s: %+v", *id, err)
	}

	d.Set("database_name", id.DatabaseName)
	d.Set("rule_id", id.RuleName)
	d.Set("baseline_name", id.BaselineName)

	vulnerabilityAssessmentId := parse.NewServerVulnerabilityAssessmentID(id.SubscriptionId, id.ResourceGroup, id.ServerName, id.VulnerabilityAssessmentName)
	d.Set("server_vulnerability_assessment_id", vulnerabilityAssessmentId.ID())

	if model := result.Model; model != nil {
		if props := model.Properties; props != nil {
			// Setting a nested set can fail if the flattened shape doesn't match the schema.
			// Surface that instead of silently leaving stale state.
			if err := d.Set("baseline_result", flattenBaselineResult(props.BaselineResults)); err != nil {
				return fmt.Errorf("setting `baseline_result`: %+v", err)
			}
		}
	}

	return nil
}

// resourceMsSqlDatabaseVulnerabilityAssessmentRuleBaselineDelete removes the rule baseline whose
// resource ID is stored in state. Any error from the API, including a 404, is returned to the caller.
func resourceMsSqlDatabaseVulnerabilityAssessmentRuleBaselineDelete(d *pluginsdk.ResourceData, meta interface{}) error {
	armClient := meta.(*clients.Client)
	baselinesClient := armClient.MSSQL.DatabaseVulnerabilityAssessmentRuleBaselinesClient
	ctx, cancel := timeouts.ForDelete(armClient.StopContext, d)
	defer cancel()

	id, err := parse.DatabaseVulnerabilityAssessmentRuleBaselineID(d.Id())
	if err != nil {
		return err
	}

	// Translate the provider's resource ID into the SDK's ID type, which has no
	// vulnerability assessment name segment.
	policyBaselineName := databasevulnerabilityassessmentrulebaselines.VulnerabilityAssessmentPolicyBaselineName(id.BaselineName)
	sdkId := databasevulnerabilityassessmentrulebaselines.NewBaselineID(
		id.SubscriptionId,
		id.ResourceGroup,
		id.ServerName,
		id.DatabaseName,
		id.RuleName,
		policyBaselineName,
	)

	_, err = baselinesClient.Delete(ctx, sdkId)
	if err != nil {
		return fmt.Errorf("deleting %s: %+v", *id, err)
	}
	return nil
}

// expandBaselineResults converts the `baseline_result` set from configuration into the SDK request
// payload. Each set element becomes one baseline row whose values are copied, in order, from that
// element's `result` list.
func expandBaselineResults(baselineResult *pluginsdk.Set) *databasevulnerabilityassessmentrulebaselines.DatabaseVulnerabilityAssessmentRuleBaseline {
	rows := baselineResult.List()
	items := make([]databasevulnerabilityassessmentrulebaselines.DatabaseVulnerabilityAssessmentRuleBaselineItem, 0, len(rows))

	for _, row := range rows {
		rawValues := row.(map[string]interface{})["result"].([]interface{})

		values := make([]string, 0, len(rawValues))
		for _, v := range rawValues {
			values = append(values, v.(string))
		}

		items = append(items, databasevulnerabilityassessmentrulebaselines.DatabaseVulnerabilityAssessmentRuleBaselineItem{
			Result: values,
		})
	}

	properties := databasevulnerabilityassessmentrulebaselines.DatabaseVulnerabilityAssessmentRuleBaselineProperties{
		BaselineResults: items,
	}
	return &databasevulnerabilityassessmentrulebaselines.DatabaseVulnerabilityAssessmentRuleBaseline{
		Properties: &properties,
	}
}

// flattenBaselineResult converts the API's baseline rows into the element shape used by the
// `baseline_result` set in state. A row whose Result is nil produces an element with no `result`
// key. The returned slice is never nil.
func flattenBaselineResult(baselineResults []databasevulnerabilityassessmentrulebaselines.DatabaseVulnerabilityAssessmentRuleBaselineItem) []map[string]interface{} {
	flattened := make([]map[string]interface{}, 0, len(baselineResults))

	for i := range baselineResults {
		entry := make(map[string]interface{})
		if values := baselineResults[i].Result; values != nil {
			entry["result"] = values
		}
		flattened = append(flattened, entry)
	}

	return flattened
}
